import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from abc import abstractmethod, ABCMeta
from .linear import *


class Basic(nn.Module, metaclass=ABCMeta):
    """Abstract classifier: a subclass-defined feature extractor plus an optional last linear layer.

    Subclasses implement :meth:`get_feature`. This base class owns the final
    fully-connected layer (``self.linear``) and a few diagnostics computed on it.

    Args:
        num_feats: Size of the feature vector fed into the last layer.
        num_classes: Number of output logits (rows of the last layer's weight).
        mode: Which last layer to build when ``need_linear`` is true:
            ``'norm'`` -> ``NormLinear(num_feats, num_classes)``,
            ``'fixnorm'`` -> ``FixNormLinear(num_feats, num_classes, weight)``,
            ``'fixlinear'`` -> ``FixLinear(num_feats, num_classes, weight)``.
            Any other value (including the default ``'linear'``) builds a
            bias-free ``nn.Linear``. The specialised layers come from the
            project's ``.linear`` module; see it for their exact semantics.
        need_linear: If false, no ``self.linear`` is created and ``forward``
            returns the raw features.
        weight: Forwarded only to ``FixNormLinear`` / ``FixLinear``; unused
            by the other modes.
    """

    def __init__(self, num_feats, num_classes, mode='linear', need_linear=True, weight=None):
        super().__init__()
        self.need_linear = need_linear
        if self.need_linear:
            if mode == 'norm':
                self.linear = NormLinear(num_feats, num_classes)
            elif mode == 'fixnorm':
                self.linear = FixNormLinear(num_feats, num_classes, weight)
            elif mode == 'fixlinear':
                self.linear = FixLinear(num_feats, num_classes, weight)
            else:
                # Unrecognised modes silently fall back to a plain bias-free layer.
                self.linear = nn.Linear(num_feats, num_classes, bias=False)

    @abstractmethod
    def get_feature(self, x):
        """
            Get the input of the last fully-connected layer (i.e. what ``self.linear`` consumes).
        """

    def _reset_params(self):
        """Re-initialise parameters of all sub-modules in place.

        - ``nn.Conv2d`` weights: Kaiming-uniform (``fan_in``, ReLU gain).
        - ``nn.Linear`` / ``NormLinear`` weights: Xavier-uniform.
        - ``nn.BatchNorm2d``: weight 1, bias 0 (skipped when the layer has no
          affine parameters).

        Biases of conv/linear layers and modules of other types are left as-is.
        This class never calls it itself; presumably subclasses invoke it at the
        end of their ``__init__`` — confirm against callers.
        """
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_uniform_(module.weight, mode='fan_in', nonlinearity='relu')
            elif isinstance(module, (nn.Linear, NormLinear)):
                nn.init.xavier_uniform_(module.weight)
            elif isinstance(module, nn.BatchNorm2d):
                # BatchNorm2d(affine=False) has weight/bias set to None; nothing to reset.
                if module.weight is not None:
                    nn.init.constant_(module.weight, 1.0)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0.0)

    def get_margin(self):
        """Summarise the geometry of the last fully-connected layer's weight rows.

        Returns:
            A tuple of three Python floats:
            - the smallest angle, in degrees, between any two distinct weight rows;
            - the smallest row L2 norm;
            - the ratio of the largest to the smallest row L2 norm.

        With a single class there is no distinct pair; the (masked) diagonal
        entry of about -1 is used instead, so the reported angle is about 180.
        Raises AttributeError if the model was built with ``need_linear=False``.
        """
        # Pure diagnostic: no need to record an autograd graph through the weight.
        with torch.no_grad():
            weight = self.get_weight()
            row_norms = weight.norm(dim=1)

            unit_rows = F.normalize(weight, dim=1)
            # Subtract 2 on the diagonal (self-similarity ~1 -> ~-1) so max() picks a distinct pair.
            cosine = unit_rows @ unit_rows.t() - 2 * torch.eye(unit_rows.size(0), device=weight.device)
            # Rounding can push the cosine of (near-)parallel rows slightly above 1,
            # which would make acos return NaN; clamp into acos's domain.
            max_cosine = cosine.max().clamp(-1.0, 1.0)

            min_angle_deg = torch.acos(max_cosine).item() / math.pi * 180
            min_norm = row_norms.min()
            ratio = row_norms.max() / min_norm
        return min_angle_deg, min_norm.item(), ratio.item()

    def get_weight(self):
        """Return the weight of the last fully-connected layer (``self.linear.weight``)."""
        return self.linear.weight

    def forward(self, x, adjusted=False, need_feat=False):
        """Compute features and, if a last layer exists, logits.

        Args:
            x: Network input, passed unchanged to :meth:`get_feature`.
            adjusted: If true, multiply each row of the logits by the L2 norm of
                the corresponding feature row (norm over dim 1, detached so no
                gradient flows through the scale factor). Assumes the features
                are 2-D ``(batch, num_feats)`` — TODO confirm for all subclasses.
            need_feat: If true, also return the features alongside the logits.

        Returns:
            ``logits``, or ``(logits, feat)`` when ``need_feat`` is true. If the
            model was built with ``need_linear=False``, returns ``feat`` only and
            ignores ``adjusted`` / ``need_feat``.
        """
        feat = self.get_feature(x)
        if not self.need_linear:
            return feat

        logits = self.linear(feat)
        if adjusted:
            # One scale per row, shape (N, 1); broadcasting over the class
            # dimension replaces an explicit repeat to (N, num_classes).
            feat_norm = torch.norm(feat, dim=1).view(-1, 1)
            logits = feat_norm.detach() * logits
        return (logits, feat) if need_feat else logits